# Copyright (c) 2020 Trail of Bits, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import argparse
import logging
from collections import defaultdict


# Path to the clang executable. It is handed to `ccsyspath` below and is
# written as a comment into the generated header.
# `.get()` is used so that an unset variable reaches the explicit check
# instead of dying with a bare KeyError before the message can be shown;
# an explicit raise (rather than `assert`) also survives `python -O`.
cc_path = os.environ.get('CLANG_EXE')
if not cc_path:
  raise RuntimeError(
    "Please specify environment variable 'CLANG_EXE' with path to clang executable")

# `ccsyspath` is optional: when it cannot be imported the list of system
# include directories simply starts out empty.
try:
  import ccsyspath
except ImportError:
  syspath = list()
else:
  syspath = ccsyspath.system_include_paths(cc_path)
  print(syspath)

# Architecture names the script knows about; `--arch` is compared against it.
SUPPORTED_ARCH = ["x86", "amd64"]

# Library flavours the script knows about; `--type` is compared against it.
SUPPORTED_LIBRARY_TYPE = ["c", "cpp"]

# Target architecture; replaced by the `--arch` value when run as a script.
ARCH_NAME = ""

# Library flavour; replaced by the `--type` value when run as a script.
ABI_LIBRARY_TYPE = "c"

# Log records of level DEBUG and above go to `debug.log` (relative path).
logging.basicConfig(filename="debug.log",level=logging.DEBUG)

# First text written to every generated source file: tells clang to ignore
# the two deprecation diagnostics.
cc_pragma = """

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated"
#pragma clang diagnostic ignored "-Wdeprecated-declarations"

"""

# Text written after the `#include` line of a generated source file: it
# declares `gets` and opens the `__mcsema_externs` array, which the writer
# functions fill and close.
# NOTE(review): `gets` is presumably declared by hand because some headers no
# longer declare it - confirm.
cc_header = """
  // mcsema ABI library, automatically generated by generate_abi_wrapper.py

extern char *gets(char *s); 

__attribute__((used))
void *__mcsema_externs[] = {
"""

# function name -> list of [mangled name, function pointer type, location],
# one entry per declaration found by `visit_func_decl`.
FUNCDECL_LIST = defaultdict(list)

# Not referenced anywhere else in this file.
FUNCDECL_MANGLED_NAME = defaultdict(list)

# Header names taken from double-quoted `#include` lines of the input file;
# appended to by `parse_headers` and read by `write_header_file`.
LOCAL_HEADERS = []

# Check whether a file exists in one of the given directories
def file_exist(all_dirs, file):
  """ Check if `file` exists.

  `file` is looked up inside every directory of `all_dirs` and, as a last
  resort, on its own (relative to the current directory, or as an absolute
  path).  Both the directories and the file name may be given as `str` or
  as `bytes`; any other object is converted with `str()`.

  Returns True when the file was found at any of these places.
  """
  def as_text(value):
    # str(b"x") yields "b'x'" on Python 3, which never names a real file,
    # so bytes are decoded with the filesystem encoding instead.
    if isinstance(value, bytes):
      return os.fsdecode(value)
    return str(value)

  name = as_text(file)
  if any(os.path.exists(as_text(directory) + '/' + name) for directory in all_dirs):
    return True
  return os.path.exists(name)

# Drop the `__attribute__((...))` pieces from the spelling of a function type
def process_function_types(type_string):
  """ Pre-process the type string of a function declaration.

  The string is cut at every single space; pieces that contain
  `__attribute__` are discarded and the remaining pieces are joined back
  together with single spaces.
  """
  kept = []
  for token in type_string.split(' '):
    if '__attribute__' in token:
      continue
    kept.append(token)
  return ' '.join(kept)

def get_function_pointer(type_string):
  """ Convert a function type spelling to the matching pointer type.

  `(*)` is inserted in front of the first opening parenthesis, e.g.
  `int (int)` becomes `int (*)(int)`.  Everything before that parenthesis
  is kept as the return type (only trailing spaces are trimmed), so
  pointer return types such as `char *(char *)` keep their `*`.

  A spelling without any parenthesis (for instance the name of a function
  typedef) is turned into a pointer by appending ` *`.

  Only the first parenthesis is considered, so spellings whose return type
  itself contains parentheses are not handled specially.
  """
  paren = type_string.find('(')
  if paren < 0:
    return type_string + " *"
  return type_string[:paren].rstrip() + " (*)" + type_string[paren:]

def is_valid_type(type_string):
  """ Tell whether a type spelling can be used as is.

  Returns False when the spelling mentions `_Complex` or `typeof`,
  True otherwise.
  """
  unsupported = ("_Complex", 'typeof')
  return not any(marker in type_string for marker in unsupported)

def is_blacklisted_func(func_name):
  """ Tell whether a function must be skipped.

  A function is skipped when its name contains the text `operator`.
  """
  return 'operator' in func_name

def visit_func_decl(node):
  """ Walk `node` and everything below it, recording function declarations.

  For every cursor of kind FUNCTION_DECL whose name is not blacklisted, a
  `[mangled name, pointer type, location]` entry is appended to
  `FUNCDECL_LIST` under the function's name.  The pointer type is derived
  from the cleaned-up type spelling, or is `void *` when that spelling is
  not considered valid.

  Does nothing when the clang Python bindings cannot be imported.
  """
  try:
    from clang.cindex import CursorKind, TypeKind
  except ImportError:
    # Without the bindings there is no cursor API to work with.
    return

  is_function = node.kind == CursorKind.FUNCTION_DECL
  if is_function:
    name, mangled = node.spelling, node.mangled_name
    if not is_blacklisted_func(name):
      cleaned = process_function_types(node.type.spelling)
      if is_valid_type(cleaned):
        pointer_type = get_function_pointer(cleaned)
      else:
        pointer_type = 'void *'
      FUNCDECL_LIST[name].append([mangled, pointer_type, node.location])

  # Children are visited for every kind of cursor, functions included.
  for child in node.get_children():
    visit_func_decl(child)


def write_cc_file(hfile, outfile):
  """ Generate the ABI library source for C headers.

  Writes to `outfile`: the pragma preamble, an `#include` of `hfile`, the
  opening of the `__mcsema_externs` array and, for every declaration
  recorded in `FUNCDECL_LIST`, a comment with its source location followed
  by a `(void *)(name)` array entry.  A function declared several times
  gets one entry per declaration.

  Prints the number of distinct function names when done.
  """
  with open(outfile, "w") as s:
    s.write(cc_pragma)
    s.write("\n\n")
    s.write("#include \"{}\"".format(hfile))
    s.write("\n\n")
    s.write(cc_header)
    s.write("\n")
    for func_name, decls in FUNCDECL_LIST.items():
      for decl in decls:
        # decl is [mangled name, pointer type, location]
        s.write("  //{}\n".format(repr(decl[2])))
        s.write("  (void *)({}),\n".format(func_name))
    s.write("};\n")
  print("Number of functions: {}".format(len(FUNCDECL_LIST)))
    
def write_cxx_file(hfile, outfile):
  """ Generate the ABI library source for C++ headers.

  Writes to `outfile`: the pragma preamble, an `#include` of `hfile`, the
  opening of the `__mcsema_externs` array and, for every declaration
  recorded in `FUNCDECL_LIST`, a comment with its source location and
  mangled name.  The array entries themselves are written as comments, so
  the generated array stays empty.

  Prints the number of distinct function names when done.
  """
  with open(outfile, "w") as s:
    s.write(cc_pragma)
    s.write("\n\n")
    s.write("#include \"{}\"".format(hfile))
    s.write("\n\n")
    s.write(cc_header)
    s.write("\n")
    # `items()` works on Python 2 and 3; `iterkeys()` is gone in Python 3.
    for func_name, decls in FUNCDECL_LIST.items():
      for decl in decls:
        # decl is [mangled name, pointer type, location]
        s.write("  //{} {}\n".format(repr(decl[2]), decl[0]))
        s.write("//  (void *)({}),\n".format(func_name))
    s.write("};\n")
  print("Number of functions: {}".format(len(FUNCDECL_LIST)))

def write_library_file(hfile, outfile):
  """ Generate the library files.

  When the `clang.cindex` bindings can be imported, `hfile` is parsed as C
  or C++ (depending on `ABI_LIBRARY_TYPE`) for the architecture named by
  `ARCH_NAME` and its function declarations are collected.  Without the
  bindings this step is skipped.  In both cases the ABI library source is
  then written to `outfile` with the writer matching the library type.

  Raises ValueError when the bindings are available but `ARCH_NAME` is
  neither `amd64` nor `x86` (compared case-insensitively).
  """
  libc_type = 'c++' if ABI_LIBRARY_TYPE == "cpp" else 'c'
  arch_flags = {'amd64': '-m64', 'x86': '-m32'}

  try:
    import clang.cindex
  except ImportError:
    # Best effort: without the bindings no declarations are collected and
    # the library is generated from whatever is already recorded.
    pass
  else:
    cc_index = clang.cindex.Index.create()
    arch_flag = arch_flags.get(ARCH_NAME.lower())
    if arch_flag is None:
      # Fail with a clear error instead of running into an unbound
      # translation unit further down.
      print("Unsupported architecture")
      raise ValueError("Unsupported architecture: {!r}".format(ARCH_NAME))
    tu = cc_index.parse(hfile, args=['-x', libc_type, arch_flag])
    visit_func_decl(tu.cursor)

  if libc_type == 'c':
    write_cc_file(hfile, outfile)
  else:
    write_cxx_file(hfile, outfile)

def write_header_file(file, headers):
  """ Write the wrapper header that pulls in all wanted headers.

  The header is stored next to `file`, using the same name with the
  extension replaced by `.h`.  It contains an include guard named after
  the file, a few feature macros, one guarded `#include <...>` per entry
  of `headers` and one guarded `#include "..."` per entry of
  `LOCAL_HEADERS`.

  The generated file name and `headers` are printed first; the generated
  file name is returned.
  """
  stem = os.path.splitext(file)[0]
  gen_filename = stem + ".h"
  print(gen_filename)
  print(headers)

  guard = os.path.basename(stem).upper()
  with open(gen_filename, "w") as s:
    emit = s.write
    emit("\n")
    emit("// {}\n".format(cc_path))
    emit("#ifndef {}_H\n".format(guard))
    emit("#define {}_H\n".format(guard))
    emit("""
#ifndef __has_include
#  define __has_include(x) 1
#endif

#define _GNU_SOURCE 1
#define _REGEX_RE_COMP
#define _BSD_SOURCE 1

""")
    emit("\n")
    # Angle-bracket includes first, then the double-quoted ones.
    for system_header in headers:
      emit("#if __has_include(<{}>)\n".format(system_header))
      emit("#  include <{}>\n".format(system_header))
      emit("#endif\n")
    for local_header in LOCAL_HEADERS:
      emit("#if __has_include(\"{}\")\n".format(local_header))
      emit("#  include \"{}\"\n".format(local_header))
      emit("#endif\n")
    emit("\n#endif\n")

  return gen_filename

  
def parse_headers(infile, outfile):
  """ Collect the headers included by `infile` and generate the outputs.

  Only lines that start with `#include` are considered.  Names given in
  angle brackets are kept when `file_exist` finds them in `syspath`;
  names given in double quotes are appended to `LOCAL_HEADERS`.  The
  wrapper header is then written and the ABI library source is generated
  from it into `outfile`.
  """
  system_headers = []
  with open(infile, "rb") as f:
    for raw_line in f:
      if not raw_line.startswith(b"#include"):
        continue
      # Decode once here so that header names are text everywhere else;
      # bytes would be rendered as "b'...'" when formatted on Python 3.
      line = raw_line.strip().decode("utf-8", "replace")

      opening, closing = line.find("<"), line.find(">")
      if 0 <= opening < closing:
        system_headers.append(line[opening + 1:closing])

      quoted = line.split("\"")
      if len(quoted) > 1:
        LOCAL_HEADERS.append(quoted[1])

  # The input is closed before anything is written: the generated header
  # may have the same path as the input file.
  header_files = [x for x in system_headers if file_exist(syspath, x)]
  hfile = write_header_file(infile, header_files)
  write_library_file(hfile, outfile)


# Script entry point: read the command line, record the requested
# architecture and library type, then generate the header and the library.
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  
  parser.add_argument(
    '--arch',
    help='Name of the architecture.',
    required=True)
  
  parser.add_argument(
    '--type',
    help='ABI Library types c/c++.',
    required=True)

  parser.add_argument(
    "--input",
    help="The input pre-processed header file",
    required=True)
  
  parser.add_argument(
    "--output",
    help="The output file generated with the script",
    required=True)
  
  args = parser.parse_args(args=sys.argv[1:])
  
  # The module never defines a `logger` object, so the `logging` module
  # functions are used directly (they honour the basicConfig call above).
  ARCH_NAME = args.arch
  if ARCH_NAME not in SUPPORTED_ARCH:
    logging.debug("Arch {} is not supported!".format(args.arch))
    
  ABI_LIBRARY_TYPE = args.type
  if ABI_LIBRARY_TYPE not in SUPPORTED_LIBRARY_TYPE:
    logging.debug("Library type {} not supported!".format(args.type))

  # Headers living next to the input file count as existing as well.
  syspath.append(os.path.dirname(os.path.abspath(args.input)))
  parse_headers(args.input, args.output)
